"""
Reference for the ogbn-products category-name mapping:
https://github.com/CurryTang/Graph-LLM/blob/master/utils.py
"""

import os

from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.transforms as T
import torch
import pandas as pd
from datasets import load_dataset
from torch_geometric.utils import to_undirected, add_remaining_self_loops

# Category names in label-index order: the name at position i is the one
# looked up for integer label i. Do not reorder.
_PRODUCT_CATEGORY_NAMES = (
    'Home & Kitchen',
    'Health & Personal Care',
    'Beauty',
    'Sports & Outdoors',
    'Books',
    'Patio, Lawn & Garden',
    'Toys & Games',
    'CDs & Vinyl',
    'Cell Phones & Accessories',
    'Grocery & Gourmet Food',
    'Arts, Crafts & Sewing',
    'Clothing, Shoes & Jewelry',
    'Electronics',
    'Movies & TV',
    'Software',
    'Video Games',
    'Automotive',
    'Pet Supplies',
    'Office Products',
    'Industrial & Scientific',
    'Musical Instruments',
    'Tools & Home Improvement',
    'Magazine Subscriptions',
    'Babycare Products',
    'NAN',
    'Appliances',
    'Kitchen & Dining',
    'Collectibles & Fine Art',
    'All Beauty',
    'Luxury Beauty',
    'Amazon Fashion',
    'Computers',
    'All Electronics',
    'Purchase Circles',
    'MP3 Players & Accessories',
    'Gift Cards',
    'School Supplies',
    'Home Improvement',
    'Camera & Photo',
    'GPS & Navigation',
    'Digital Music',
    'Car Electronics',
    'Baby',
    'Kindle Store',
    'Kindle Apps',
    'Furniture & Decor',
    'Others',
)

# Every category maps to itself; the dict keeps the insertion order above.
products_mapping = {category: category for category in _PRODUCT_CATEGORY_NAMES}

# Ordered list of category names, indexable by integer label.
products_keys_list = list(products_mapping)

def get_raw_text(seed=0, dataset_folder="/data/shared/zhexu/"):
    """Load the ogbn-products subset graph together with its raw node text.

    Args:
        seed: Unused by this loader; kept so existing callers that pass it
            continue to work.
        dataset_folder: Directory that contains
            ``ogbn_products/ogbn-products_subset.pt`` and
            ``ogbn_products_orig/ogbn-products_subset.csv``. A trailing
            path separator is optional.

    Returns:
        A tuple ``(data, text, products_keys_list)`` where

        * ``data`` is the object loaded from the ``.pt`` file, with
          ``edge_index`` replaced by a ``[2, num_edges]`` tensor built from
          the symmetrized ``adj_t``, made undirected, with self-loops added
          for nodes that lack one;
        * ``text`` is a dict with the keys ``'title'``, ``'content'`` and
          ``'label'``, each a list of strings;
        * ``products_keys_list`` is the module-level list of category names
          (the same list object, not a copy).
    """
    # Join path components instead of concatenating strings so that a
    # dataset_folder without a trailing slash still resolves correctly.
    graph_path = os.path.join(dataset_folder, 'ogbn_products', 'ogbn-products_subset.pt')
    text_path = os.path.join(dataset_folder, 'ogbn_products_orig', 'ogbn-products_subset.csv')

    # NOTE: torch.load unpickles the file; only point this at trusted data.
    data = torch.load(graph_path)
    org_text = pd.read_csv(text_path)

    # str() turns any non-string cell into text so every entry is a str.
    title = [str(x) for x in org_text["title"].tolist()]
    content = [str(x) for x in org_text["content"].tolist()]
    # Integer labels index into the ordered category-name list.
    label = [products_mapping[products_keys_list[i]] for i in data.y.flatten().tolist()]
    text = {'title': title, 'content': content, 'label': label}

    # Convert the symmetrized sparse adjacency into a dense [2, E] edge list.
    row, col, _ = data.adj_t.to_symmetric().coo()
    data.edge_index = torch.stack([row, col], dim=0)

    num_nodes = data.num_nodes
    # num_nodes is passed by keyword: in recent torch_geometric releases the
    # second positional parameter of to_undirected is edge_attr, not num_nodes.
    data.edge_index = to_undirected(data.edge_index, num_nodes=num_nodes)
    data.edge_index, _ = add_remaining_self_loops(data.edge_index, num_nodes=num_nodes)

    return data, text, products_keys_list